// Copyright 2017 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"fmt"
	"io/ioutil"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"regexp"
	"runtime"
	"sort"
	"strings"

	"github.com/google/gnostic/compiler"
	"github.com/google/gnostic/jsonschema"
	"github.com/google/gnostic/printer"
)

// protoOptionsForExtensions are the file options shared by every generated
// extension .proto file. generateExtension appends the per-extension
// java_package, objc_class_prefix and go_package options after these.
// Each Comment is a ready-made "//"-prefixed comment block attached to its
// option.
var protoOptionsForExtensions = []ProtoOption{
	{
		Name:  "java_multiple_files",
		Value: "true",
		Comment: "// This option lets the proto compiler generate Java code inside the package\n" +
			"// name (see below) instead of inside an outer class. It creates a simpler\n" +
			"// developer experience by reducing one-level of name nesting and be\n" +
			"// consistent with most programming languages that don't support outer classes.",
	},

	{
		Name:  "java_outer_classname",
		Value: "VendorExtensionProto",
		Comment: "// The Java outer classname should be the filename in UpperCamelCase. This\n" +
			"// class is only used to hold proto descriptor, so developers don't need to\n" +
			"// work with it directly.",
	},
}

// additionalCompilerCodeWithMain is the body of the generated extension's
// main.go; generateMainFile prepends the license, package clause and imports.
// It is a fmt format string with a single %s verb that receives the
// concatenated switch cases built from caseStringForObjectTypes and
// caseStringForWrapperTypes. Extension names without a matching case reach
// the default branch, which returns (false, nil, nil).
const additionalCompilerCodeWithMain = "" +
	"func handleExtension(extensionName string, yamlInput string) (bool, proto.Message, error) {\n" +
	"      switch extensionName {\n" +
	"      // All supported extensions\n" +
	"      %s\n" +
	"      default:\n" +
	"        return false, nil, nil\n" +
	"       }\n" +
	"}\n" +
	"\n" +
	"func main() {\n" +
	"	gnostic_extension_v1.Main(handleExtension)\n" +
	"}\n"

// caseStringForObjectTypes is a fmt format string for one generated switch
// case that handles an object-typed extension. Its %s verbs are, in order:
// the extension name (the case label), the Go package of the generated
// compiler, and the schema definition name whose New<Name> constructor is
// called on the parsed YAML.
// NOTE(review): the generated code dereferences info.Content[0] without a
// length check; confirm yaml.Unmarshal cannot yield an empty document here
// (e.g. for empty yamlInput).
const caseStringForObjectTypes = "\n" +
	"case \"%s\":\n" +
	"var info yaml.Node\n" +
	"err := yaml.Unmarshal([]byte(yamlInput), &info)\n" +
	"if err != nil {\n" +
	"  return true, nil, err\n" +
	"}\n" +
	"info = *info.Content[0]\n" +
	"newObject, err := %s.New%s(&info, compiler.NewContext(\"$root\", &info, nil))\n" +
	"return true, newObject, err"

// caseStringForWrapperTypes is a fmt format string for one generated switch
// case that handles a primitive-typed extension by storing the scalar value
// in a wrapperspb message. Its %s verbs are, in order:
//   - the extension name (the case label),
//   - the prefix of the compiler.<X>ForScalarNode helper
//     (primitiveTypeInfo.goTypeName),
//   - the wrapperspb message type (primitiveTypeInfo.wrapperProtoName).
//
// When the helper reports !ok, the generated case returns (true, nil, nil).
const caseStringForWrapperTypes = "\n" +
	"case \"%s\":\n" +
	"var info yaml.Node\n" +
	"err := yaml.Unmarshal([]byte(yamlInput), &info)\n" +
	"if err != nil {\n" +
	"  return true, nil, err\n" +
	"}\n" +
	"v, ok := compiler.%sForScalarNode(&info)\n" +
	"if !ok {\n" +
	"	return true, nil, nil\n" +
	"}\n" +
	"newObject := &wrapperspb.%s{Value: v}\n" +
	"return true, newObject, nil"

// generateMainFile generates the main program for an extension.
//
// The returned source is, in order:
//   - license,
//   - a generated-file marker,
//   - the package clause for packageName,
//   - an import block with one quoted line per entry of imports,
//   - codeBody.
//
// The text is not gofmt-formatted; generateExtension runs gofmt on the file
// after writing it.
func generateMainFile(packageName string, license string, codeBody string, imports []string) string {
	code := &printer.Code{}
	// Code.Print treats its first argument as a format string (see the
	// package clause below), so verbatim text goes through "%s"; otherwise a
	// '%' in the license or generated body would be mangled.
	code.Print("%s", license)
	code.Print("// THIS FILE IS AUTOMATICALLY GENERATED.\n")

	// generate package declaration
	code.Print("package %s\n", packageName)

	code.Print("import (")
	for _, importPath := range imports {
		// %q yields a valid Go string literal even for unusual paths.
		code.Print("%q", importPath)
	}
	code.Print(")\n")

	code.Print("%s", codeBody)
	return code.String()
}

// getBaseFileNameWithoutExt returns the last element of filePath with its
// extension (everything from the final '.' of that element) removed,
// e.g. "dir/x-book.json" -> "x-book".
func getBaseFileNameWithoutExt(filePath string) string {
	base := filepath.Base(filePath)
	extLen := len(filepath.Ext(base))
	return base[:len(base)-extLen]
}

// protoPackageNameDisallowedChars matches every run of characters that
// toProtoPackageName drops from its input. It is compiled once rather than on
// every call.
var protoPackageNameDisallowedChars = regexp.MustCompile("[^0-9A-Za-z_]+")

// toProtoPackageName converts an extension name into a proto package name.
//
// Characters outside [0-9A-Za-z_] are removed and ASCII uppercase letters are
// lowered. An underscore is inserted before each uppercase letter, unless that
// letter is the first character or already follows an underscore.
// For example, "fooBar" and "FooBar" both become "foo_bar".
func toProtoPackageName(input string) string {
	cleaned := protoPackageNameDisallowedChars.ReplaceAllString(input, "")
	// After filtering, cleaned is pure ASCII, so byte indexing is safe.
	var out strings.Builder
	out.Grow(2 * len(cleaned))
	for i := 0; i < len(cleaned); i++ {
		c := cleaned[i]
		if c >= 'A' && c <= 'Z' {
			if i > 0 && cleaned[i-1] != '_' {
				out.WriteByte('_')
			}
			c += 'a' - 'A'
		}
		out.WriteByte(c)
	}
	return out.String()
}

// primitiveTypeInfo names the identifiers used in generated handler code for
// one JSON schema primitive type (see caseStringForWrapperTypes).
type primitiveTypeInfo struct {
	// goTypeName is substituted into compiler.<goTypeName>ForScalarNode,
	// e.g. "String".
	goTypeName       string
	// wrapperProtoName is the wrapperspb message type that holds the value,
	// e.g. "StringValue".
	wrapperProtoName string
}

// supportedPrimitiveTypeInfos maps the JSON schema "type" values allowed for
// non-object extension schemas to the identifiers used when generating their
// handlers. generateExtension reports any other non-object type as an error.
var supportedPrimitiveTypeInfos = map[string]primitiveTypeInfo{
	"string":  {goTypeName: "String", wrapperProtoName: "StringValue"},
	"number":  {goTypeName: "Float", wrapperProtoName: "DoubleValue"},
	"integer": {goTypeName: "Int", wrapperProtoName: "Int64Value"},
	"boolean": {goTypeName: "Bool", wrapperProtoName: "BoolValue"},
	// TODO: Investigate how to support arrays. For now users will not be allowed to
	// create extension handlers for arrays and they will have to use the
	// plain yaml string as is.
}

// generatedTypeInfo records which schema definition handles one extension
// name, and whether that definition is a primitive type.
type generatedTypeInfo struct {
	// schemaName is the schema definition name; for object types it is also
	// used as the generated constructor suffix (New<schemaName>).
	schemaName string
	// if this is not nil, the schema should be treated as a primitive type.
	optionalPrimitiveTypeInfo *primitiveTypeInfo
}

// generateExtension generates the implementation of an extension.
//
// schemaFile is a JSON schema whose base file name must start with "x-"; the
// remainder of that name is the extension name. The following files are
// written below outDir/gnostic-x-<name>:
//   - proto/<base>.proto: protocol buffer messages for the schema definitions
//   - proto/<base>.go: the generated compiler for those messages
//   - main.go: a main program dispatching extension names to handlers
//
// Both Go files are formatted in place with the gofmt binary under
// runtime.GOROOT().
//
// Every schema definition must have a unique, non-empty 'id' (the extension
// name it handles). Its type must be absent, "object", or a key of
// supportedPrimitiveTypeInfos. All violations are collected and returned
// together, before any output is written.
func generateExtension(schemaFile string, outDir string) error {
	outFileBaseName := getBaseFileNameWithoutExt(schemaFile)
	// generateExtensions already enforces the prefix; check again so that any
	// other caller gets an error instead of a slice-bounds panic.
	if !strings.HasPrefix(outFileBaseName, "x-") {
		return fmt.Errorf("schema file name %q has to start with 'x-'", schemaFile)
	}
	extensionNameWithoutXDashPrefix := strings.TrimPrefix(outFileBaseName, "x-")
	outDir = path.Join(outDir, "gnostic-x-"+extensionNameWithoutXDashPrefix)
	protoPackage := toProtoPackageName(extensionNameWithoutXDashPrefix)
	protoPackageName := strings.ToLower(protoPackage)
	goPackageName := protoPackageName

	protoOutDirectory := path.Join(outDir, "proto")

	baseSchema, err := jsonschema.NewBaseSchema()
	if err != nil {
		return err
	}
	baseSchema.ResolveRefs()
	baseSchema.ResolveAllOfs()

	openapiSchema, err := jsonschema.NewSchemaFromFile(schemaFile)
	if err != nil {
		return err
	}
	openapiSchema.ResolveRefs()
	openapiSchema.ResolveAllOfs()

	// build a simplified model of the types described by the schema
	cc := NewDomain(openapiSchema, "v2") // TODO fix for OpenAPI v3

	// Map each extension name (a definition's 'id') to the definition that
	// implements it, and create a type model for each definition.
	extensionNameToMessageName := make(map[string]generatedTypeInfo)
	schemaErrors := make([]error, 0)
	supportedPrimitives := make([]string, 0, len(supportedPrimitiveTypeInfos))
	for key := range supportedPrimitiveTypeInfos {
		supportedPrimitives = append(supportedPrimitives, key)
	}
	sort.Strings(supportedPrimitives)
	if cc.Schema.Definitions != nil {
		for _, pair := range *(cc.Schema.Definitions) {
			definitionName := pair.Name
			definitionSchema := pair.Value
			// ensure the id field is set
			if definitionSchema.ID == nil || len(*(definitionSchema.ID)) == 0 {
				schemaErrors = append(schemaErrors,
					fmt.Errorf("schema %s has no 'id' field, which must match the "+
						"name of the OpenAPI extension that the schema represents",
						definitionName))
			} else {
				id := *(definitionSchema.ID)
				schemaType := definitionSchema.Type
				if existing, ok := extensionNameToMessageName[id]; ok {
					schemaErrors = append(schemaErrors,
						fmt.Errorf("schema %s and %s have the same 'id' field value",
							definitionName, existing.schemaName))
				} else if schemaType == nil || (schemaType.String != nil && *schemaType.String == "object") {
					extensionNameToMessageName[id] = generatedTypeInfo{schemaName: definitionName}
				} else if schemaType.String == nil {
					// Presumably an array-valued type such as ["string", "null"].
					// Report it rather than dereferencing a nil pointer.
					schemaErrors = append(schemaErrors,
						fmt.Errorf("schema %s has a 'type' field that is not a single string, which is not supported",
							definitionName))
				} else if val, ok := supportedPrimitiveTypeInfos[*schemaType.String]; ok {
					// this is a primitive type; val is a fresh variable per
					// iteration, so taking its address is safe.
					extensionNameToMessageName[id] = generatedTypeInfo{schemaName: definitionName, optionalPrimitiveTypeInfo: &val}
				} else {
					schemaErrors = append(schemaErrors,
						fmt.Errorf("Schema %s has type '%s' which is "+
							"not supported. Supported primitive types are "+
							"%s.\n", definitionName,
							*schemaType.String,
							supportedPrimitives))
				}
			}
			typeName := cc.TypeNameForStub(definitionName)
			typeModel := cc.BuildTypeForDefinition(typeName, definitionName, definitionSchema)
			if typeModel != nil {
				cc.TypeModels[typeName] = typeModel
			}
		}
	}
	if len(schemaErrors) > 0 {
		// error has been reported.
		return compiler.NewErrorGroupOrNil(schemaErrors)
	}

	err = os.MkdirAll(outDir, os.ModePerm)
	if err != nil {
		return err
	}

	err = os.MkdirAll(protoOutDirectory, os.ModePerm)
	if err != nil {
		return err
	}

	// Generate the protocol buffer description. The shared defaults are copied
	// first so that appending can never write into the package-level slice's
	// backing array.
	protoOptions := make([]ProtoOption, 0, len(protoOptionsForExtensions)+3)
	protoOptions = append(protoOptions, protoOptionsForExtensions...)
	protoOptions = append(protoOptions,
		ProtoOption{Name: "java_package", Value: "org.openapi.extension." + protoPackageName, Comment: "// The Java package name must be proto package name with proper prefix."},
		ProtoOption{
			// NOTE(review): this option's comment asks for an all-uppercase
			// prefix, but the value is lowercase; kept unchanged so the
			// generated output stays stable.
			Name: "objc_class_prefix", Value: protoPackageName,
			Comment: "// A reasonable prefix for the Objective-C symbols generated from the package.\n" +
				"// It should at a minimum be 3 characters long, all uppercase, and convention\n" +
				"// is to use an abbreviation of the package name. Something short, but\n" +
				"// hopefully unique enough to not conflict with things that may come along in\n" +
				"// the future. 'GPB' is reserved for the protocol buffer implementation itself.",
		},
		ProtoOption{
			Name:    "go_package",
			Value:   "./;" + protoPackageName,
			Comment: "// The Go package path.",
		},
	)

	protoCode := cc.generateProto(protoPackageName, License, protoOptions, nil)
	protoFilename := path.Join(protoOutDirectory, outFileBaseName+".proto")

	err = ioutil.WriteFile(protoFilename, []byte(protoCode), 0644)
	if err != nil {
		return err
	}

	// gofmt formats a generated Go file in place using the gofmt binary under
	// runtime.GOROOT(), adding the file name to any failure.
	gofmt := func(filename string) error {
		if err := exec.Command(runtime.GOROOT()+"/bin/gofmt", "-w", filename).Run(); err != nil {
			return fmt.Errorf("running gofmt on %s: %w", filename, err)
		}
		return nil
	}

	// Generate the compiler. The result is named compilerCode so that it does
	// not shadow the imported compiler package.
	compilerCode := cc.GenerateCompiler(goPackageName, License, []string{
		"fmt",
		"regexp",
		"strings",
		"github.com/google/gnostic/compiler",
		"go.yaml.in/yaml/v3",
	})
	goFilename := path.Join(protoOutDirectory, outFileBaseName+".go")
	err = ioutil.WriteFile(goFilename, []byte(compilerCode), 0644)
	if err != nil {
		return err
	}
	if err = gofmt(goFilename); err != nil {
		return err
	}

	// generate the main file.

	// TODO: This path is currently fixed to the location of the samples.
	//       Can we make it relative, perhaps with an option or by generating
	//       a go.mod file for the generated extension handler?
	outDirRelativeToPackageRoot := "github.com/google/gnostic/extensions/sample/" + outDir

	// Emit the cases in sorted order so the generated file is deterministic.
	extensionNameKeys := make([]string, 0, len(extensionNameToMessageName))
	for k := range extensionNameToMessageName {
		extensionNameKeys = append(extensionNameKeys, k)
	}
	sort.Strings(extensionNameKeys)

	wrapperTypeIncluded := false
	var cases strings.Builder
	for _, extensionName := range extensionNameKeys {
		typeInfo := extensionNameToMessageName[extensionName]
		if typeInfo.optionalPrimitiveTypeInfo == nil {
			fmt.Fprintf(&cases, caseStringForObjectTypes,
				extensionName,
				goPackageName,
				typeInfo.schemaName)
		} else {
			wrapperTypeIncluded = true
			fmt.Fprintf(&cases, caseStringForWrapperTypes,
				extensionName,
				typeInfo.optionalPrimitiveTypeInfo.goTypeName,
				typeInfo.optionalPrimitiveTypeInfo.wrapperProtoName)
		}
	}
	extMainCode := fmt.Sprintf(additionalCompilerCodeWithMain, cases.String())
	imports := []string{
		"github.com/google/gnostic/extensions",
		"github.com/google/gnostic/compiler",
		"google.golang.org/protobuf/proto",
		"go.yaml.in/yaml/v3",
		outDirRelativeToPackageRoot + "/" + "proto",
	}
	if wrapperTypeIncluded {
		imports = append(imports, "google.golang.org/protobuf/types/known/wrapperspb")
	}
	mainCode := generateMainFile("main", License, extMainCode, imports)
	mainFileName := path.Join(outDir, "main.go")
	err = ioutil.WriteFile(mainFileName, []byte(mainCode), 0644)
	if err != nil {
		return err
	}

	// format the main program
	return gofmt(mainFileName)
}

// generateExtensions parses os.Args for extension-generation mode and then
// calls generateExtension.
//
// Recognized arguments after the program name:
//   - --out_dir=DIR: output directory (required)
//   - --extension: accepted and ignored here
//   - FILE: the JSON schema; its base name must start with "x-" (the last one
//     given wins)
//
// Empty arguments are skipped. On any usage error (an unknown flag, a missing
// schema or output directory, or a bad schema file name) it prints a message
// followed by usage() and calls os.Exit(-1) instead of returning.
func generateExtensions() error {
	outDir := ""
	schemaFile := ""

	// The pattern is anchored, so only an argument that is entirely
	// "--name=value" counts as a flag. [^=]+ ends the name at the first '=',
	// so values may themselves contain '=' (e.g. --out_dir=a=b).
	extParamRegex := regexp.MustCompile("^--([^=]+)=(.+)$")

	for i, arg := range os.Args {
		if i == 0 {
			continue // skip the tool name
		}
		m := extParamRegex.FindStringSubmatch(arg)
		switch {
		case m != nil:
			flagName, flagValue := m[1], m[2]
			switch flagName {
			case "out_dir":
				outDir = flagValue
			default:
				fmt.Printf("Unknown option: %s.\n%s\n", arg, usage())
				os.Exit(-1)
			}
		case arg == "--extension":
			// accepted and ignored
		case arg == "":
			// Nothing to parse. Indexing the first byte of an empty argument
			// would panic, so skip it.
		case strings.HasPrefix(arg, "-"):
			fmt.Printf("Unknown option: %s.\n%s\n", arg, usage())
			os.Exit(-1)
		default:
			schemaFile = arg
		}
	}

	if schemaFile == "" {
		fmt.Printf("No input json schema specified.\n%s\n", usage())
		os.Exit(-1)
	}
	if outDir == "" {
		fmt.Printf("Missing output directive.\n%s\n", usage())
		os.Exit(-1)
	}
	if !strings.HasPrefix(getBaseFileNameWithoutExt(schemaFile), "x-") {
		fmt.Printf("Schema file name has to start with 'x-'.\n%s\n", usage())
		os.Exit(-1)
	}

	return generateExtension(schemaFile, outDir)
}
